#!/usr/bin/python

### Largely self-contained image binarization and deskewing.  You can easily 
### use this as a basis for other kinds of preprocessing.

import sys,os,re,optparse,shutil,glob,fcntl
import signal
signal.signal(signal.SIGINT,lambda *args:sys.exit(1))
import traceback
import argparse
import multiprocessing

import matplotlib
if "DISPLAY" not in os.environ: matplotlib.use("AGG")
else: matplotlib.use("GTK")
from matplotlib import patches

# all the image processing code comes from scipy and pylab
from pylab import *
from scipy.stats.stats import trim1
from multiprocessing import Pool
from scipy.ndimage import measurements,interpolation
from scipy.misc import imsave
from pylab import gray as gray_

# ocrolib is only used for image I/O and pathname manipulation
import ocrolib
# ocrorast is used for estimating the skew angle
import ocrorast

parser = argparse.ArgumentParser(
    # keep the hand-formatted paragraphs and bullet lists below as written
    formatter_class=argparse.RawDescriptionHelpFormatter,
    description = """
%(prog)s -o dir [options] image1 image2 ...

Perform document image preprocessing:

- Sauvola binarization
- deskewing
- large and small component removal

Images are processed from the command line and put into a standard book directory,
creating

- book/0001.png (deskewed grayscale page image)
- book/0001.bin.png (deskewed and cleaned binary page image)

This assumes 300dpi images (i.e., all the internal thresholds and constants
are set up for that).  If your image is a different resolution, use the -z (zoom)
argument.
""")

parser.add_argument("args",default=[],nargs='*',help="input images")

# output / interaction
parser.add_argument("-o","--output",help="output directory",default="book")
parser.add_argument("-O","--Output",help="output image.png to image.bin.png (in place)",action="store_true")
parser.add_argument("-d","--display",help="display result",action="store_true")
parser.add_argument("-D","--Display",help="display continuously",action="store_true")
parser.add_argument("-q","--silent",action="store_true",help="disable warnings")
parser.add_argument("-Q","--parallel",type=int,default=multiprocessing.cpu_count(),help="number of parallel processes to use")
parser.add_argument("-g","--gtextension",help="ground truth extension for copying in ground truth (include all dots)",default=None)
parser.add_argument("--debug",action="store_true")

# preprocessing steps
# parser.add_argument("--dpi",default=300,type=float,help="resolution (DPI) (300)")
parser.add_argument("-z","--zoom",type=float,default=1.0,help="rescale the image prior to processing")
parser.add_argument("--maxsize",type=int,default=300,help="maximum character component size")
parser.add_argument("--minsize",type=int,default=5,help="minimum character component size")
parser.add_argument("--binarize",action="store_true",help="always run binarization, even if image appears binary already")
parser.add_argument("--invert",action="store_true",help="invert the image prior to binarization")
parser.add_argument("--htrem",action="store_true",help="always remove halftones (even if there don't appear to be any)")
parser.add_argument("--nohtrem",action="store_true",help="never remove halftones (even if there appear to be some)")
parser.add_argument("--uncleaned",action="store_true",help="output only the deskewed binary image with no further cleanup")
parser.add_argument("--noskew",action="store_true",help="do not perform skew correction")

# sauvola
parser.add_argument("-s","--sigma",type=float,default=150,help="sigma argument for Sauvola binarization")
parser.add_argument("-k","--k",type=float,default=0.3,help="k value for Sauvola binarization")

# hysteresis thresholding
# TBD

args = parser.parse_args()

if len(args.args)<1:
    parser.print_help()
    sys.exit(0)

# display/debug modes are interactive, so force serial processing for them
if args.debug or args.display or args.Display: args.parallel = 1

################################################################
# preprocessing
################################################################

import os,os.path
from pylab import *
from scipy.ndimage import measurements,interpolation,filters,morphology
import math

def dread(fname):
    """Read an image with imread, flipping JPEG files vertically.

    The flip compensates for imread returning JPEG images upside down, as the
    original note on this function states.  NOTE(review): whether this is
    still true depends on the imread backend in use -- verify.
    """
    extension = os.path.splitext(fname)[1]
    pixels = imread(fname)
    if extension.lower() not in [".jpg",".jpeg"]:
        return pixels
    return pixels[::-1,:,...]

################################################################
### Binarization
################################################################

def is_binary(image):
    """Return True if more than 99% of the pixels hold an extreme value.

    A pixel counts when it equals the image minimum or the image maximum.
    For a constant image every pixel is counted twice, so it is reported
    as binary too.
    """
    darkest = amin(image)
    brightest = amax(image)
    n_extreme = sum(image==darkest) + sum(image==brightest)
    return n_extreme > 0.99*image.size
    
def gsauvola(image,sigma=150.0,R=None,k=0.3,filter='uniform',scale=2.0):
    """Perform Sauvola-like binarization.

    The local mean and standard deviation are computed with linear filters
    on a copy of the image downsampled by ``scale``.  They are upsampled back
    to the input resolution and combined into the threshold
    ``avg * (1 + k * (stddev / R - 1))``.

    image  -- 2D grayscale image, or 3D image whose last axis is averaged.
              uint8 input is divided by 256 first.
    sigma  -- window size (uniform) or sigma (gaussian) of the filter, in
              downsampled pixels.  The uniform window size is truncated to
              an integer.
    R      -- normalization for the standard deviation; defaults to its
              maximum over the image.
    k      -- Sauvola sensitivity parameter.
    filter -- 'uniform', 'gaussian', or a callable ``f(array, sigma)``.
    scale  -- downsampling factor used to speed up the filtering.

    Returns a uint8 array with the (channel-averaged) image's shape.  A pixel
    is 255 where it is brighter than its local threshold and 0 elsewhere.

    Raises ValueError if ``filter`` is neither a known name nor callable.
    """
    if image.dtype==dtype('uint8'): image = image / 256.0
    if len(image.shape)==3: image = mean(image,axis=2)
    # work in floating point: squaring integer images (e.g. uint16) overflows
    image = array(image,'d')
    if filter=="gaussian":
        filter = filters.gaussian_filter
    elif filter=="uniform":
        # uniform_filter expects an integer window size
        filter = lambda data,size: filters.uniform_filter(data,int(size))
    elif not callable(filter):
        raise ValueError("unknown filter: %r" % (filter,))
    scaled = interpolation.zoom(image,1.0/scale,order=0,mode='nearest')
    s1 = filter(ones(scaled.shape),sigma)
    sx = filter(scaled,sigma)
    sxx = filter(scaled**2,sigma)
    avg_ = sx / s1
    stddev_ = maximum(sxx/s1 - avg_**2,0.0)**0.5
    # Upsample with per-axis factors so the result has exactly the input
    # shape.  A fixed ``scale`` zoom into a sliced output buffer mismatches
    # for odd-sized images.
    factors = [float(full)/small for full,small in zip(image.shape,avg_.shape)]
    avg = interpolation.zoom(avg_,factors,order=0,mode='nearest')
    stddev = interpolation.zoom(stddev_,factors,order=0,mode='nearest')
    if R is None: R = amax(stddev)
    thresh = avg * (1.0 + k * (stddev / R - 1.0))
    return array(255*(image>thresh),'uint8')

def inverse(image):
    """Return ``max(image) - image``: the brightest pixel maps to 0."""
    peak = amax(image)
    return peak - image

def autoinvert(image):
    """Invert the image when its median is above the midpoint of its range.

    The median follows the majority of pixels, so after this call the
    majority (background) pixels sit in the lower half of the value range.
    Otherwise the image is returned unchanged.
    """
    midrange = mean([amax(image),amin(image)])
    if median(image) > midrange:
        return amax(image) - image
    return image

################################################################
### Bounding-box operations.
################################################################

def bounding_boxes_math(image):
    """Return bounding boxes of the bright connected components.

    Pixels above the midpoint between the image's minimum and maximum are
    foreground.  Each box is ``(x0, y0, x1, y1)`` in "mathematical"
    coordinates: x is the column index, and y is measured from the bottom
    edge (image height minus the row index).
    """
    midpoint = mean([amax(image),amin(image)])
    labels,_ = measurements.label(image>midpoint)
    height,_ = labels.shape
    boxes = []
    for rows,cols in measurements.find_objects(labels):
        boxes.append((cols.start,height-rows.stop,cols.stop,height-rows.start))
    return boxes

def select_plausible_char_bboxes(bboxes,dpi=300.0,min_aspect=0.25,max_aspect=1.0):
    """Keep only bounding boxes that plausibly belong to single characters.

    Boxes are ``(x0, y0, x1, y1)`` tuples.  A box is dropped when:

    - its width or height is below ``5 * dpi/300`` pixels;
    - its width or height is above ``100 * dpi/300`` pixels; or
    - its aspect ratio ``width/height`` lies outside
      ``[min_aspect, max_aspect]``.

    The size limits scale with ``dpi``.  The aspect-ratio limits are
    dimensionless and do not.
    """
    s = dpi/300.0
    min_size = 5*s
    max_size = 100*s
    result = []
    for b in bboxes:
        x0,y0,x1,y1 = b
        w = x1-x0
        h = y1-y0
        if w<min_size or h<min_size: continue
        if w>max_size or h>max_size: continue
        if h<=0: continue  # degenerate box; only reachable when dpi<=0
        a = w*1.0/h
        if a>max_aspect or a<min_aspect: continue
        result.append(b)
    return result

def estimate_skew_angle(image):
    """Estimate the skew angle of a binary document image, in radians.

    Character-sized bounding boxes (bounding_boxes_math followed by
    select_plausible_char_bboxes) are fed to the RAST text-line finder,
    which is run without constraints and limited to one line.  The result
    is ``atan`` of that line's ``m`` parameter (presumably its slope).
    Returns 0.0 when the finder reports no line.

    Raises ValueError if the image is not binary or contains no plausible
    character boxes.
    """
    if not is_binary(image):
        raise ValueError("estimate_skew_angle: image is not binary")
    bboxes = select_plausible_char_bboxes(bounding_boxes_math(image))
    if not bboxes:
        raise ValueError("estimate_skew_angle: no plausible character boxes found")
    finder = ocrorast.make_TextLineRAST2()
    finder.setMaxLines(1)
    for c in bboxes:
        finder.addChar(*c)
    finder.compute()
    if finder.nlines()<1: return 0.0
    return math.atan(finder.getLine_m(0))

def deskew(image):
    """Rotate the image by the negative of the angle from estimate_skew_angle.

    The rotation uses order-0 (nearest-neighbour) interpolation with
    'nearest' edge handling.
    """
    degrees = -estimate_skew_angle(image)*180/pi
    return interpolation.rotate(image,degrees,mode='nearest',order=0)

def check_contains_halftones(image,dpi=300.0):
    """Heuristically decide whether halftone removal should be applied.

    A connected component counts as "big" when its bounding box is wider or
    taller than ``4*dpi/300`` pixels.  Returns True when fewer than 30% of
    all components are big, i.e. tiny components dominate.  An image with no
    components returns False.
    """
    boxes = bounding_boxes_math(image)
    limit = 4*dpi/300.0
    big_boxes = [b for b in boxes if b[2]-b[0]>limit or b[3]-b[1]>limit]
    return len(big_boxes) < 0.3*len(boxes)

def remove_small_components(image,r=3):
    """Erase connected components that fit within an r x r box.

    A component survives if its bounding box is taller than ``r`` or wider
    than ``r``; every other component is erased.  Returns a boolean mask of
    the remaining nonzero pixels.
    """
    labels,_ = measurements.label(image)
    for idx,box in enumerate(measurements.find_objects(labels),1):
        tall = box[0].stop-box[0].start>r
        wide = box[1].stop-box[1].start>r
        if tall or wide:
            continue
        window = labels[box]
        window[window==idx] = 0
    return labels!=0

def remove_big_components(image,r=100):
    """Remove connected components that are r pixels or more in either dimension.

    Only components whose bounding box has height < r AND width < r are
    kept.  Returns a boolean mask of the remaining nonzero pixels.
    """
    labels,_ = measurements.label(image)
    for idx,box in enumerate(measurements.find_objects(labels),1):
        height = box[0].stop-box[0].start
        width = box[1].stop-box[1].start
        if height<r and width<r: continue
        window = labels[box]
        window[window==idx] = 0
    return labels!=0

def remove_small_any(image,r=3):
    """Remove both small connected components and small holes.

    Steps:

    1. Erase components that fit inside an r x r box.
    2. Invert the mask, so background holes become components.
    3. Erase the small ones (which fills those holes).
    4. Invert back.

    Returns a boolean mask.
    """
    def _flip(mask):
        # Same result as the former ``amax(mask)-mask`` on a boolean mask:
        # a mask with any True pixel is inverted, and an all-False mask is
        # returned unchanged.  NumPy >= 1.13 raises TypeError on boolean
        # subtraction, so compare against the maximum instead.
        return mask!=amax(mask)
    image = remove_small_components(image,r=r)
    image = _flip(image)
    image = remove_small_components(image,r=r)
    return _flip(image)

def rectangular_cover(image,minsize=5):
    """Replace every connected component by its filled bounding box.

    Components whose bounding box is shorter or narrower than ``minsize``
    are dropped.  Returns a float array of the input's shape: 1.0 inside the
    kept boxes, 0.0 elsewhere.
    """
    labels,_ = measurements.label(image)
    covered = zeros(labels.shape)
    big_enough = lambda sl: sl.stop-sl.start>=minsize
    for box in measurements.find_objects(labels):
        if big_enough(box[0]) and big_enough(box[1]):
            covered[box] = 1
    return covered

def find_halftones(image,dpi=300.0,threshold=0.05,r=5,sigma=15.0,cover=1):
    """Locate halftone regions in an image.

    1. Mark the pixels that change when small components and small holes
       (up to r x r) are removed.
    2. Estimate their density with a Gaussian of width ``sigma*dpi/300``.
    3. Threshold the density at ``threshold``.

    With ``cover`` true, returns rectangular_cover of the thresholded
    density.  Otherwise returns the union of the changed pixels and the
    dense region.
    """
    changed = (image!=0) != (remove_small_any(image,r=r)!=0)
    smoothed = filters.gaussian_filter(1.0*changed,sigma*dpi/300.0)
    dense = smoothed>threshold
    if not cover:
        return maximum(changed,dense)
    return rectangular_cover(dense)

def remove_halftones(image,dpi=300.0,threshold=0.05,r=5,sigma=15.0):
    """Blank out the regions flagged by find_halftones.

    The image maximum is subtracted wherever the halftone mask is set, and
    the result is clipped at 0.  Pixels outside the mask keep their value.
    """
    mask = find_halftones(image,dpi=dpi,threshold=threshold,r=r,sigma=sigma)
    suppressed = image - amax(image)*mask
    return maximum(suppressed,0)

################################################################
### All preprocessing steps put together.
################################################################

def preprocess(raw,title=None):
    """Binarize, clean up and (optionally) deskew one page image.

    Behaviour is driven by the global command-line ``args``: zoom, Sauvola
    -s/-k, --binarize, --invert, --htrem/--nohtrem, --minsize/--maxsize,
    --noskew, --uncleaned and --debug.  A "cleaned" copy of the
    binarization, with halftones and too small / too large components
    removed, is used for skew estimation.

    Returns ``(bin, gray)``:

    - With skew correction: the deskewed uncleaned binarization
      (re-thresholded to uint8 0/255) and the deskewed grayscale image.
    - With --noskew: the cleaned binarization (the uncleaned one if
      --uncleaned is given) and the possibly zoomed/inverted input image.

    Raises ValueError if binarization yields a constant image.  Skew
    estimation may also raise.
    """
    if args.debug:
        subplot(121); imshow(raw); ginput(1,0.001)
        if title is not None: xlabel(title)
        subplot(122)

    # zoom if requested

    if args.zoom!=1.0:
        raw = interpolation.zoom(raw,args.zoom,mode='nearest',order=1)
        if args.debug:
            cla(); xlabel("zoomed")
            imshow(raw); ginput(1,9999)

    # binarize if the image isn't already binary; pass the -s/-k
    # command-line values on to the Sauvola binarizer

    if args.binarize or not is_binary(raw):
        bin = gsauvola(raw,sigma=args.sigma,k=args.k)
    else:
        bin = array(255*(raw>0.5*(amax(raw)+amin(raw))),'B')
    if not amax(bin)>amin(bin):
        raise ValueError("something went wrong with binarization")

    # invert on request

    if args.invert:
        raw = inverse(raw)
        bin = inverse(bin)

    if args.debug:
        cla(); xlabel("binarized and inverted")
        imshow(bin); ginput(1,9999)

    # keep the uncleaned binarization around; --uncleaned outputs it
    orig = bin

    # now clean up for skew estimation

    cleaned = bin

    if args.htrem or (not args.nohtrem and check_contains_halftones(bin)):
        cleaned = remove_halftones(cleaned)
        if args.debug:
            cla(); xlabel("half tone removal")
            imshow(cleaned); ginput(1,9999)

    if args.minsize>0:
        cleaned = remove_small_components(cleaned,args.minsize)
        if args.debug:
            cla(); xlabel("minsize filtering")
            imshow(cleaned); ginput(1,9999)

    if args.maxsize<10000:
        cleaned = remove_big_components(cleaned,args.maxsize)
        if args.debug:
            cla(); xlabel("maxsize filtering")
            imshow(cleaned); ginput(1,9999)

    if args.noskew:
        if args.uncleaned: return orig,raw
        return cleaned,raw

    # perform skew estimation on the cleaned image, then rotate the
    # uncleaned binarization (which is also what --uncleaned asks for)
    # NOTE(review): the usage text calls book/NNNN.bin.png "deskewed and
    # cleaned", and the --noskew path returns the cleaned image; confirm
    # whether the deskewed output should be the cleaned image as well.

    a = estimate_skew_angle(cleaned)
    angle = -a*180/pi
    gray = interpolation.rotate(raw,angle,mode='nearest',order=0)
    bin = interpolation.rotate(orig,angle,mode='nearest',order=0)
    # re-threshold to a uint8 0/255 image
    bin = array(255*(bin>0.5*(amin(bin)+amax(bin))),'B')
    if args.debug:
        cla(); xlabel("skew correction by %f"%a)
        imshow(bin); ginput(1,9999)
    return bin,gray

################################################################
### main loop
################################################################

if args.Display: args.display = 1
if args.display: ion(); show()
if args.Output:
    args.output = None
if not args.Output and os.path.exists(args.output):
    # refuse to overwrite an existing book directory; report failure
    print >>sys.stderr, "%s: already exists; please remove"%args.output
    sys.exit(1)

# In book mode, FILES records "<page number>\t<source path>" per output page.
# It is opened here, before any worker pool is created, and appended to
# under flock by process_image.
files = None
if not args.Output:
    os.mkdir(args.output)
    files = open(args.output+"/FILES","w")

def process_image(image,arg,count):
    if amax(image)<=amin(image)+1e-4:
        print arg,"is empty"
        return

    if args.display: 
        clf()
        imshow(image,cmap=cm.gray)
        draw()
        ginput(1,timeout=1)

    title = None
    if args.debug:
        ion(); clf(); gray_()
        title = "%s %s"%(arg,count)
    try:
        bin,gray = preprocess(image,title=title)
        print bin.shape,gray.shape
    except:
        print arg,count,"failed"
        traceback.print_exc()
        return
    if args.display: 
        clf()
        imshow(bin,cmap=cm.gray)
        draw()
        if not args.Display: 
            raw_input("hit ENTER to continue")
        else:
            ginput(1,timeout=1)
    if args.Output:
        dest,_ = ocrolib.allsplitext(arg)
        print "# writing",dest
        imsave(dest+".png",gray,cmap=cm.gray)
        imsave(dest+".bin.png",bin,cmap=cm.gray)
    else:
        dest = "%s/%04d" % (args.output,count)
        print "# writing",dest,gray.shape,bin.shape
        imsave(dest+".png",gray,cmap=cm.gray)
        imsave(dest+".bin.png",bin,cmap=cm.gray)
        if args.gtextension is not None:
            base,_ = ocrolib.allsplitext(arg)
            shutil.copyfile(base+args.gtextension,dest+".gt.txt")
        if files is not None:
            fcntl.flock(files,fcntl.LOCK_EX)
            files.write("%04d\t%s\n"%(count,arg))
            files.flush()
            fcntl.flock(files,fcntl.LOCK_UN)

def process1(t):
    arg,count = t
    n = 0
    for image,arg in ocrolib.page_iterator([arg]):
        assert n<2,"no multipage files with parallel processing; use -P 0"
        print "===",arg,count,image.shape
        try:
            process_image(image,arg,count)
        except:
            traceback.print_exc()
            raise
        n += 1

if args.parallel<2:
    count = 0
    for image,arg in ocrolib.page_iterator(args.args):
        print "===",arg,count,image.shape
        count += 1
        process_image(image,arg,count)
else:
    pool = Pool(processes=args.parallel)
    jobs = []
    for i in range(len(args.args)): jobs += [(args.args[i],i+1)]
    result = pool.map(process1,jobs)

if files is not None:
    files.close()
